{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Augmentation for Geospatial Training\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/data_augmentation.ipynb)\n",
    "\n",
    "This notebook demonstrates how to use data augmentation when preparing training tiles and training segmentation models. Data augmentation helps improve model generalization by creating variations of training data through transformations like flips, rotations, and photometric adjustments.\n",
    "\n",
    "## Key Features\n",
    "\n",
    "- **Tile Export with Augmentation**: Generate augmented versions of tiles during export\n",
    "- **Default Augmentation Transforms**: Use pre-configured transforms optimized for remote sensing\n",
    "- **Custom Augmentation**: Define your own augmentation pipeline\n",
    "- **Enhanced Training Defaults**: Improved default augmentations for segmentation model training\n",
    "\n",
    "## Install package\n",
    "\n",
    "Uncomment the following line to install the geoai package if needed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -U geoai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import geoai\n",
    "import numpy as np\n",
    "import rasterio\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download sample data\n",
    "\n",
    "We'll use the same NAIP imagery and building footprint dataset as used in the [train_segmentation_model.ipynb](https://opengeoai.org/examples/train_segmentation_model) example. This is real aerial imagery with building annotations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download NAIP imagery and building footprints\n",
    "train_raster_url = (\n",
    "    \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif\"\n",
    ")\n",
    "train_vector_url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_train_buildings.geojson\"\n",
    "\n",
    "sample_image = geoai.download_file(train_raster_url)\n",
    "sample_vector = geoai.download_file(train_vector_url)\n",
    "\n",
    "print(f\"Downloaded sample image: {sample_image}\")\n",
    "print(f\"Downloaded sample labels: {sample_vector}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize sample data\n",
    "\n",
    "Let's visualize the NAIP imagery and building footprints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Image information:\")\n",
    "geoai.get_raster_info(sample_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize on interactive map\n",
    "geoai.view_vector_interactive(sample_vector, tiles=sample_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1: Export Tiles WITHOUT Augmentation (Baseline)\n",
    "\n",
    "First, let's export tiles without augmentation to establish a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create output directory\n",
    "output_dir = \"buildings_augmentation_demo\"\n",
    "\n",
    "# Export without augmentation\n",
    "output_no_aug = f\"{output_dir}/tiles_no_augmentation\"\n",
    "\n",
    "geoai.export_geotiff_tiles(\n",
    "    sample_image,\n",
    "    output_no_aug,\n",
    "    in_class_data=sample_vector,\n",
    "    tile_size=256,\n",
    "    stride=128,\n",
    "    apply_augmentation=False,  # No augmentation\n",
    ")\n",
    "\n",
    "# Count tiles\n",
    "image_tiles = list(Path(output_no_aug, \"images\").glob(\"*.tif\"))\n",
    "label_tiles = list(Path(output_no_aug, \"labels\").glob(\"*.tif\"))\n",
    "\n",
    "print(f\"\\nWithout augmentation:\")\n",
    "print(f\"  Image tiles: {len(image_tiles)}\")\n",
    "print(f\"  Label tiles: {len(label_tiles)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2: Export Tiles WITH Default Augmentation\n",
    "\n",
    "Now let's export tiles with default augmentation. The `get_default_augmentation_transforms()` function provides sensible defaults for remote sensing data:\n",
    "\n",
    "**Geometric Transforms:**\n",
    "- Horizontal/Vertical Flips (50% probability each)\n",
    "- Random 90° Rotations (50% probability)\n",
    "- Shift-Scale-Rotate (50% probability)\n",
    "\n",
    "**Photometric Transforms:**\n",
    "- Random Brightness/Contrast (50% probability)\n",
    "- HSV Color Adjustments (30% probability)\n",
    "- Gaussian Noise (20% probability)\n",
    "- Gaussian Blur (20% probability)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export with default augmentation - generate 3 augmented versions per tile\n",
    "output_with_aug = f\"{output_dir}/tiles_with_augmentation\"\n",
    "\n",
    "geoai.export_geotiff_tiles(\n",
    "    sample_image,\n",
    "    output_with_aug,\n",
    "    in_class_data=sample_vector,\n",
    "    tile_size=256,\n",
    "    stride=128,\n",
    "    apply_augmentation=True,  # Enable augmentation\n",
    "    augmentation_count=3,  # Generate 3 augmented versions per tile\n",
    ")\n",
    "\n",
    "# Count tiles\n",
    "aug_image_tiles = list(Path(output_with_aug, \"images\").glob(\"*.tif\"))\n",
    "aug_label_tiles = list(Path(output_with_aug, \"labels\").glob(\"*.tif\"))\n",
    "\n",
    "print(f\"\\nWith augmentation (3 per tile):\")\n",
    "print(f\"  Image tiles: {len(aug_image_tiles)} (original + augmented)\")\n",
    "print(f\"  Label tiles: {len(aug_label_tiles)} (original + augmented)\")\n",
    "print(f\"  \\nThis is {len(aug_image_tiles) / len(image_tiles):.1f}x more training data!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize Original vs Augmented Tiles\n",
    "\n",
    "Let's compare an original tile with its augmented versions to see the transformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load original tile and some augmented versions\n",
    "tile_files = sorted([f for f in Path(output_with_aug, \"images\").glob(\"*.tif\")])\n",
    "label_files = sorted([f for f in Path(output_with_aug, \"labels\").glob(\"*.tif\")])\n",
    "\n",
    "# Get first 4 tiles (1 original + 3 augmented)\n",
    "n_display = min(4, len(tile_files))\n",
    "\n",
    "fig, axes = plt.subplots(2, n_display, figsize=(15, 8))\n",
    "if n_display == 1:\n",
    "    axes = axes.reshape(2, 1)\n",
    "\n",
    "for i in range(n_display):\n",
    "    # Load and display image tile\n",
    "    with rasterio.open(tile_files[i]) as src:\n",
    "        img = src.read()  # Read all bands\n",
    "        # If RGB, display as color\n",
    "        if img.shape[0] >= 3:\n",
    "            img_display = np.transpose(img[:3], (1, 2, 0))\n",
    "            axes[0, i].imshow(img_display)\n",
    "        else:\n",
    "            axes[0, i].imshow(img[0], cmap=\"gray\")\n",
    "\n",
    "        title = \"Original\" if i == 0 else f\"Augmented {i}\"\n",
    "        axes[0, i].set_title(f\"{title}\\n{tile_files[i].name}\")\n",
    "        axes[0, i].axis(\"off\")\n",
    "\n",
    "    # Load and display label tile\n",
    "    with rasterio.open(label_files[i]) as src:\n",
    "        label = src.read(1)\n",
    "        axes[1, i].imshow(label, cmap=\"tab10\", vmin=0, vmax=10)\n",
    "        axes[1, i].set_title(f\"Label\\n{label_files[i].name}\")\n",
    "        axes[1, i].axis(\"off\")\n",
    "\n",
    "fig.suptitle(\"Original Tile vs Augmented Versions\", fontsize=14, y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\n",
    "    \"Notice how the augmented tiles have different orientations, colors, and brightness\"\n",
    ")\n",
    "print(\"while the labels are transformed consistently with the images.\")"
   ]
  },
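  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison above relies on each label tile receiving the same geometric transform as its image tile. How `export_geotiff_tiles()` handles this internally is not shown in this notebook; the sketch below only illustrates the general albumentations mechanism, where passing `image` and `mask` to the same `A.Compose` call applies one set of sampled geometric parameters to both. The synthetic 8x8 arrays are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic 8x8 RGB tile with a bright rectangle in the left half\n",
    "image = np.zeros((8, 8, 3), dtype=np.uint8)\n",
    "image[2:6, 1:3] = 255\n",
    "\n",
    "# Matching label mask: 1 where the rectangle is, 0 elsewhere\n",
    "mask = np.zeros((8, 8), dtype=np.uint8)\n",
    "mask[2:6, 1:3] = 1\n",
    "\n",
    "# Passing image and mask together applies the same geometric transform to both\n",
    "flip = A.Compose([A.HorizontalFlip(p=1.0)])\n",
    "out = flip(image=image, mask=mask)\n",
    "\n",
    "# Compare where the rectangle sits before and after the flip\n",
    "print(\"Mask columns before:\", np.flatnonzero(mask.any(axis=0)))\n",
    "print(\"Mask columns after: \", np.flatnonzero(out[\"mask\"].any(axis=0)))\n",
    "\n",
    "# The flipped image and the flipped mask should still cover the same pixels\n",
    "aligned = np.array_equal(out[\"image\"][..., 0] > 0, out[\"mask\"] > 0)\n",
    "print(\"Image and mask still aligned:\", aligned)"
   ]
  },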
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 3: Custom Augmentation Pipeline\n",
    "\n",
    "You can also define your own custom augmentation transforms using albumentations. This is useful when you want specific augmentations for your use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "\n",
    "# Define custom augmentation pipeline\n",
    "custom_transforms = A.Compose(\n",
    "    [\n",
    "        A.HorizontalFlip(p=1.0),  # Always flip horizontally\n",
    "        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.8),\n",
    "        A.HueSaturationValue(\n",
    "            hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Export with custom augmentation\n",
    "output_custom_aug = f\"{output_dir}/tiles_custom_augmentation\"\n",
    "\n",
    "geoai.export_geotiff_tiles(\n",
    "    sample_image,\n",
    "    output_custom_aug,\n",
    "    in_class_data=sample_vector,\n",
    "    tile_size=256,\n",
    "    stride=128,\n",
    "    apply_augmentation=True,\n",
    "    augmentation_count=2,\n",
    "    augmentation_transforms=custom_transforms,  # Use custom transforms\n",
    ")\n",
    "\n",
    "custom_image_tiles = list(Path(output_custom_aug, \"images\").glob(\"*.tif\"))\n",
    "print(f\"\\nWith custom augmentation (2 per tile):\")\n",
    "print(f\"  Image tiles: {len(custom_image_tiles)}\")"
   ]
  },
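  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If orientation carries meaning in your imagery, a custom pipeline can leave out flips and rotations entirely. The sketch below builds a photometric-only pipeline and checks on synthetic data that the mask passes through unchanged. The transform choices and probabilities are illustrative assumptions, not geoai defaults. A pipeline like this could be passed through `augmentation_transforms`, as in the export call above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "import numpy as np\n",
    "\n",
    "# Photometric-only pipeline (illustrative values): pixel values change, geometry does not\n",
    "photometric_only = A.Compose(\n",
    "    [\n",
    "        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),\n",
    "        A.GaussianBlur(blur_limit=(3, 5), p=0.2),\n",
    "        A.GaussNoise(p=0.2),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Random uint8 RGB tile and a binary mask standing in for real data\n",
    "rng = np.random.default_rng(0)\n",
    "image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)\n",
    "mask = (rng.random((64, 64)) > 0.5).astype(np.uint8)\n",
    "\n",
    "out = photometric_only(image=image, mask=mask)\n",
    "\n",
    "# Pixel-level transforms are applied to the image only, so the mask is untouched\n",
    "print(\"Mask unchanged:\", np.array_equal(out[\"mask\"], mask))\n",
    "print(\"Image shape:\", out[\"image\"].shape)"
   ]
  },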
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 4: Using get_default_augmentation_transforms()\n",
    "\n",
    "You can also access the default augmentation transforms directly to use in your own workflows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from geoai.utils import get_default_augmentation_transforms\n",
    "\n",
    "# Get default transforms\n",
    "default_transforms = get_default_augmentation_transforms(\n",
    "    tile_size=256, include_normalize=False  # Don't normalize for visualization\n",
    ")\n",
    "\n",
    "print(\"Default augmentation pipeline:\")\n",
    "print(default_transforms)\n",
    "\n",
    "# Apply to a sample image\n",
    "with rasterio.open(tile_files[0]) as src:\n",
    "    original_img = src.read()\n",
    "    # Convert to HWC format for albumentations\n",
    "    img_hwc = np.transpose(original_img, (1, 2, 0))\n",
    "\n",
    "# Apply augmentation\n",
    "augmented = default_transforms(image=img_hwc)\n",
    "aug_img = augmented[\"image\"]\n",
    "\n",
    "# Visualize\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
    "axes[0].imshow(img_hwc)\n",
    "axes[0].set_title(\"Original\")\n",
    "axes[0].axis(\"off\")\n",
    "\n",
    "axes[1].imshow(aug_img)\n",
    "axes[1].set_title(\"Augmented (random transform)\")\n",
    "axes[1].axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 5: Training with Enhanced Default Augmentation\n",
    "\n",
    "The `train_segmentation_model()` function now uses improved default augmentations that include:\n",
    "- Horizontal and vertical flips\n",
    "- Random 90° rotations\n",
    "- Brightness and contrast adjustments\n",
    "\n",
    "Here's an example of how you would train with these defaults (we won't actually run training in this demo):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from geoai.train import train_segmentation_model\n",
    "\n",
    "# Train with enhanced default augmentation\n",
    "model = train_segmentation_model(\n",
    "    images_dir=f\"{output_dir}/tiles_with_augmentation/images\",\n",
    "    labels_dir=f\"{output_dir}/tiles_with_augmentation/labels\",\n",
    "    output_dir=f\"{output_dir}/tiles_with_augmentation/training_output\",\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_classes=4,\n",
    "    batch_size=8,\n",
    "    num_epochs=20,\n",
    "    # The following default augmentations are applied automatically:\n",
    "    # - Horizontal flips (50%)\n",
    "    # - Vertical flips (50%)\n",
    "    # - Random 90° rotations (50%)\n",
    "    # - Brightness adjustment (50%)\n",
    "    # - Contrast adjustment (50%)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.plot_performance_metrics(\n",
    "    history_path=f\"{output_dir}/tiles_with_augmentation/training_output/training_history.pth\",\n",
    "    figsize=(15, 5),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 6: Custom Training Transforms\n",
    "\n",
    "If you want to use custom augmentations during training, you can define your own transform functions and pass them to `train_segmentation_model()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from geoai.train import (\n",
    "    train_segmentation_model,\n",
    "    SemanticTransforms,\n",
    "    SemanticRandomHorizontalFlip,\n",
    "    SemanticToTensor,\n",
    ")\n",
    "\n",
    "# Define custom training transforms\n",
    "custom_train_transforms = SemanticTransforms(\n",
    "    [\n",
    "        SemanticToTensor(),\n",
    "        SemanticRandomHorizontalFlip(0.5),\n",
    "        # Add more custom transforms here...\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Train with custom augmentation\n",
    "model = train_segmentation_model(\n",
    "    images_dir=f\"{output_dir}/tiles_no_augmentation/images\",\n",
    "    labels_dir=f\"{output_dir}/tiles_no_augmentation/labels\",\n",
    "    output_dir=f\"{output_dir}/tiles_no_augmentation/training_output\",\n",
    "    architecture=\"unet\",\n",
    "    num_classes=4,\n",
    "    train_transforms=custom_train_transforms,  # Use custom transforms\n",
    "    num_epochs=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.plot_performance_metrics(\n",
    "    history_path=f\"{output_dir}/tiles_no_augmentation/training_output/training_history.pth\",\n",
    "    figsize=(15, 5),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This notebook demonstrated:\n",
    "\n",
    "1. **Tile Export with Augmentation**: Using `export_geotiff_tiles()` with `apply_augmentation=True` to generate augmented training data\n",
    "2. **Default Augmentation**: Using `get_default_augmentation_transforms()` for sensible defaults optimized for remote sensing\n",
    "3. **Custom Augmentation**: Defining custom augmentation pipelines with albumentations\n",
    "4. **Training with Augmentation**: How the enhanced defaults work in `train_segmentation_model()`\n",
    "\n",
    "### Key Benefits of Data Augmentation:\n",
    "\n",
    "- **More Training Data**: Generate 2-5x more training samples from existing data\n",
    "- **Better Generalization**: Models learn to handle variations in orientation, lighting, and appearance\n",
    "- **Reduced Overfitting**: More diverse training data helps prevent memorization\n",
    "- **Improved Accuracy**: Typically results in 2-5% better validation accuracy\n",
    "\n",
    "### Best Practices:\n",
    "\n",
    "1. Start with default augmentation - it works well for most remote sensing tasks\n",
    "2. Use 2-5 augmented versions per tile (more isn't always better)\n",
    "3. Ensure augmentations match your domain (e.g., avoid vertical flips if imagery has a consistent \"up\" direction)\n",
    "4. Monitor validation performance to ensure augmentations help rather than hurt\n",
    "\n",
    "For more information, see the [geoai documentation](https://opengeoai.org)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
